package com.xiaoxin.executor.batch.partitioner;

import org.apache.commons.lang3.StringUtils;
import org.springframework.batch.core.partition.support.Partitioner;
import org.springframework.batch.item.ExecutionContext;
import org.springframework.jdbc.core.JdbcTemplate;

import java.math.BigInteger;
import java.text.MessageFormat;
import java.util.HashMap;
import java.util.Map;
import javax.sql.DataSource;

/**
 * Query-based partitioner: splits the numeric values of one column into contiguous,
 * inclusive {@code [min, max]} ranges.
 *
 * <p>The overall bounds are read with {@code select min(column)} / {@code select max(column)}
 * from {@code table}, optionally restricted by {@code whereCondition} (which, when set, must be a
 * complete SQL fragment such as {@code "where status = 'A'"}). Each partition's
 * {@link ExecutionContext} receives its bounds as decimal strings under the keys
 * {@code "_minRecord"} and {@code "_maxRecord"}. At most {@code gridSize} partitions are produced.
 *
 * <p>NOTE(review): {@code table}, {@code column} and {@code whereCondition} are spliced into the SQL
 * text verbatim (identifiers cannot be bound as parameters), so they must come from trusted
 * configuration, never from user input.
 *
 * @author ZhangXX
 * 2023-01-09 14:44
 */
public class DataBasePartitioner implements Partitioner {

    /** Execution-context key holding a partition's inclusive lower bound. */
    private static final String _MINRECORD = "_minRecord";

    /** Execution-context key holding a partition's inclusive upper bound. */
    private static final String _MAXRECORD = "_maxRecord";

    private static final String MIN_SELECT_PATTERN = "select min({0}) from {1} {2}";

    private static final String MAX_SELECT_PATTERN = "select max({0}) from {1} {2}";

    /**
     * Template shared by every instance. It is only ever assigned through
     * {@link #setJdbcTemplate(JdbcTemplate)}; when present it takes precedence over an instance's
     * own {@link #dataSource}.
     */
    private static JdbcTemplate jdbcTemplate;

    /**
     * Template lazily built from this instance's own {@link #dataSource}. It is kept per instance
     * so that partitioners configured with different data sources do not share one connection
     * target.
     */
    private JdbcTemplate dataSourceJdbcTemplate;

    private DataSource dataSource;

    private String table;

    private String column;

    private String whereCondition;

    /**
     * Splits {@code [min(column), max(column)]} into at most {@code gridSize} contiguous ranges.
     *
     * @param gridSize requested number of partitions; must be at least 1
     * @return partition name ({@code "partition0"}, {@code "partition1"}, ...) to context holding
     *     the range bounds; empty when no row matches
     * @throws IllegalArgumentException if {@code gridSize < 1} or the configuration is incomplete
     */
    @Override
    public Map<String, ExecutionContext> partition(int gridSize) {
        if (gridSize < 1) {
            // 0 would divide by zero; a negative size can make the step <= 0 and loop forever.
            throw new IllegalArgumentException("gridSize must be at least 1 but was " + gridSize);
        }
        validateAndInit();
        JdbcTemplate template = resolveJdbcTemplate();

        Map<String, ExecutionContext> resultMap = new HashMap<>();
        BigInteger min = queryBound(template, MIN_SELECT_PATTERN);
        BigInteger max = queryBound(template, MAX_SELECT_PATTERN);
        if (min == null || max == null) {
            // MIN/MAX over an empty row set yield NULL: nothing to partition.
            return resultMap;
        }

        // Step = (max - min) / gridSize + 1 guarantees gridSize * step > max - min, so at most
        // gridSize ranges are needed to cover [min, max].
        BigInteger targetSize = max.subtract(min)
                .divide(BigInteger.valueOf(gridSize))
                .add(BigInteger.ONE);

        int number = 0;
        for (BigInteger start = min; start.compareTo(max) <= 0; start = start.add(targetSize)) {
            // The last range is clamped so it never extends past max.
            BigInteger end = start.add(targetSize).subtract(BigInteger.ONE).min(max);
            ExecutionContext context = new ExecutionContext();
            context.putString(_MINRECORD, start.toString());
            context.putString(_MAXRECORD, end.toString());
            resultMap.put("partition" + (number++), context);
        }
        return resultMap;
    }

    /**
     * Checks the mandatory configuration and prepares the JDBC template for this instance.
     *
     * @throws IllegalArgumentException if {@code table} or {@code column} is blank, or if neither a
     *     shared {@link JdbcTemplate} nor a {@link DataSource} has been configured
     */
    public void validateAndInit() {
        if (StringUtils.isBlank(table)) {
            throw new IllegalArgumentException("table cannot be null");
        }
        if (StringUtils.isBlank(column)) {
            throw new IllegalArgumentException("column cannot be null");
        }
        if (dataSource != null && dataSourceJdbcTemplate == null) {
            dataSourceJdbcTemplate = new JdbcTemplate(dataSource);
        }
        if (resolveJdbcTemplate() == null) {
            throw new IllegalArgumentException("jdbcTemplate cannot be null");
        }
    }

    /** Returns the explicitly configured shared template if any, else this instance's own one. */
    private JdbcTemplate resolveJdbcTemplate() {
        return jdbcTemplate != null ? jdbcTemplate : dataSourceJdbcTemplate;
    }

    /** Runs one min/max aggregate query built from {@code pattern}; may return {@code null}. */
    private BigInteger queryBound(JdbcTemplate template, String pattern) {
        // MessageFormat renders a null argument as the text "null", which would corrupt the SQL.
        String sql = MessageFormat.format(pattern, column, table, StringUtils.defaultString(whereCondition));
        return template.queryForObject(sql, BigInteger.class);
    }

    /** Returns the template shared by all instances, or {@code null} if none was set. */
    public static JdbcTemplate getJdbcTemplate() {
        return jdbcTemplate;
    }

    /** Sets a template shared by all instances; it overrides any per-instance data source. */
    public static void setJdbcTemplate(JdbcTemplate jdbcTemplate) {
        DataBasePartitioner.jdbcTemplate = jdbcTemplate;
    }

    public DataSource getDataSource() {
        return dataSource;
    }

    /** Sets this instance's data source and discards any template built from a previous one. */
    public void setDataSource(DataSource dataSource) {
        this.dataSource = dataSource;
        this.dataSourceJdbcTemplate = null;
    }

    public String getTable() {
        return table;
    }

    public void setTable(String table) {
        this.table = table;
    }

    public String getColumn() {
        return column;
    }

    public void setColumn(String column) {
        this.column = column;
    }

    public String getWhereCondition() {
        return whereCondition;
    }

    public void setWhereCondition(String whereCondition) {
        this.whereCondition = whereCondition;
    }
}
